Source code for nlp_architect.models.gnmt.utils.evaluation_utils

# ******************************************************************************
# Copyright 2017-2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
# Changes Made from original:
#   import paths
# ==============================================================================
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: skip-file
"""Utility for evaluating various tasks, e.g., translation & summarization."""
import codecs
import os
import re
import subprocess

import tensorflow as tf

from nlp_architect.models.gnmt.scripts import bleu
from nlp_architect.models.gnmt.scripts import rouge

__all__ = ["evaluate"]


def evaluate(ref_file, trans_file, metric, subword_option=None):
    """Pick a metric and evaluate depending on task."""
    # BLEU scores for translation task
    if metric.lower() == "bleu":
        evaluation_score = _bleu(ref_file, trans_file,
                                 subword_option=subword_option)
    # ROUGE scores for summarization tasks
    elif metric.lower() == "rouge":
        evaluation_score = _rouge(ref_file, trans_file,
                                  subword_option=subword_option)
    elif metric.lower() == "accuracy":
        evaluation_score = _accuracy(ref_file, trans_file)
    elif metric.lower() == "word_accuracy":
        evaluation_score = _word_accuracy(ref_file, trans_file)
    else:
        raise ValueError("Unknown metric %s" % metric)

    return evaluation_score
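

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): one way evaluate()
# might be called after decoding a test set.  The file paths below are
# hypothetical placeholders, not files shipped with nlp_architect.
# ---------------------------------------------------------------------------
def _example_evaluate_usage():
    """Illustrative only: score a decoded file against a reference file."""
    ref_file = "/tmp/test.ref"      # hypothetical reference, one sentence per line
    trans_file = "/tmp/test.trans"  # hypothetical model output, aligned line by line
    # metric can be "bleu", "rouge", "accuracy" or "word_accuracy";
    # subword_option ("bpe"/"spm") is only applied when cleaning the
    # reference side for the BLEU and ROUGE metrics.
    score = evaluate(ref_file, trans_file, metric="bleu", subword_option="bpe")
    print("BLEU: %.2f" % score)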


def _clean(sentence, subword_option):
    """Clean and handle BPE or SPM outputs."""
    sentence = sentence.strip()

    # BPE
    if subword_option == "bpe":
        sentence = re.sub("@@ ", "", sentence)
    # SPM
    elif subword_option == "spm":
        sentence = u"".join(sentence.split()).replace(u"\u2581", u" ").lstrip()

    return sentence


# Follow //transconsole/localization/machine_translation/metrics/bleu_calc.py
def _bleu(ref_file, trans_file, subword_option=None):
    """Compute BLEU scores, handling BPE."""
    max_order = 4
    smooth = False

    ref_files = [ref_file]
    reference_text = []
    for reference_filename in ref_files:
        with codecs.getreader("utf-8")(
                tf.gfile.GFile(reference_filename, "rb")) as fh:
            reference_text.append(fh.readlines())

    per_segment_references = []
    for references in zip(*reference_text):
        reference_list = []
        for reference in references:
            reference = _clean(reference, subword_option)
            reference_list.append(reference.split(" "))
        per_segment_references.append(reference_list)

    translations = []
    with codecs.getreader("utf-8")(tf.gfile.GFile(trans_file, "rb")) as fh:
        for line in fh:
            line = _clean(line, subword_option=None)
            translations.append(line.split(" "))

    # bleu_score, precisions, bp, ratio, translation_length, reference_length
    bleu_score, _, _, _, _, _ = bleu.compute_bleu(
        per_segment_references, translations, max_order, smooth)
    return 100 * bleu_score


def _rouge(ref_file, summarization_file, subword_option=None):
    """Compute ROUGE scores, handling BPE."""
    references = []
    with codecs.getreader("utf-8")(tf.gfile.GFile(ref_file, "rb")) as fh:
        for line in fh:
            references.append(_clean(line, subword_option))

    hypotheses = []
    with codecs.getreader("utf-8")(
            tf.gfile.GFile(summarization_file, "rb")) as fh:
        for line in fh:
            hypotheses.append(_clean(line, subword_option=None))

    rouge_score_map = rouge.rouge(hypotheses, references)
    return 100 * rouge_score_map["rouge_l/f_score"]


def _accuracy(label_file, pred_file):
    """Compute accuracy; each line contains a label."""
    with codecs.getreader("utf-8")(tf.gfile.GFile(label_file, "rb")) as label_fh:
        with codecs.getreader("utf-8")(tf.gfile.GFile(pred_file, "rb")) as pred_fh:
            count = 0.0
            match = 0.0
            for label in label_fh:
                label = label.strip()
                pred = pred_fh.readline().strip()
                if label == pred:
                    match += 1
                count += 1
    return 100 * match / count


def _word_accuracy(label_file, pred_file):
    """Compute accuracy on a per-word basis."""
    with codecs.getreader("utf-8")(tf.gfile.GFile(label_file, "r")) as label_fh:
        with codecs.getreader("utf-8")(tf.gfile.GFile(pred_file, "r")) as pred_fh:
            total_acc, total_count = 0., 0.
            for sentence in label_fh:
                labels = sentence.strip().split(" ")
                preds = pred_fh.readline().strip().split(" ")
                match = 0.0
                for pos in range(min(len(labels), len(preds))):
                    label = labels[pos]
                    pred = preds[pos]
                    if label == pred:
                        match += 1
                total_acc += 100 * match / max(len(labels), len(preds))
                total_count += 1
    return total_acc / total_count
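

# ---------------------------------------------------------------------------
# Hedged illustration (not part of the original module) of what _clean(),
# defined above, does for the two subword options.  The sample strings are
# invented for demonstration purposes only.
# ---------------------------------------------------------------------------
def _example_clean_usage():
    """Illustrative only: undo BPE and SPM segmentation markers."""
    # "@@ " marks a split inside a word produced by BPE segmentation.
    bpe_sentence = _clean("the fu@@ nny cat", subword_option="bpe")
    # u"\u2581" marks the start of a word produced by sentencepiece (SPM).
    spm_sentence = _clean(u"\u2581the \u2581fun ny \u2581cat", subword_option="spm")
    # Both calls return u"the funny cat".
    return bpe_sentence, spm_sentence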


def _moses_bleu(multi_bleu_script, tgt_test, trans_file, subword_option=None):
    """Compute BLEU scores using Moses multi-bleu.perl script."""
    # TODO(thangluong): perform rewrite using python
    # BPE: score against a de-BPE'd copy of the reference file.
    if subword_option == "bpe":
        debpe_tgt_test = tgt_test + ".debpe"
        if not os.path.exists(debpe_tgt_test):
            # Explicit argument lists avoid shell=True, which can be a
            # security hazard.
            subprocess.call(["cp", tgt_test, debpe_tgt_test])
            subprocess.call(["sed", "-i", "s/@@ //g", debpe_tgt_test])
        tgt_test = debpe_tgt_test
    # SPM: undo sentencepiece segmentation in a copy of the reference file.
    elif subword_option == "spm":
        despm_tgt_test = tgt_test + ".despm"
        if not os.path.exists(despm_tgt_test):
            subprocess.call(["cp", tgt_test, despm_tgt_test])
            subprocess.call(["sed", "-i", "s/ //g", despm_tgt_test])
            subprocess.call(["sed", "-i", u"s/^\u2581//g", despm_tgt_test])
            subprocess.call(["sed", "-i", u"s/\u2581/ /g", despm_tgt_test])
        tgt_test = despm_tgt_test

    # multi-bleu.perl takes the reference file as an argument and reads the
    # hypotheses from stdin; feed the translation file through stdin instead
    # of relying on shell redirection.
    with open(trans_file, "rb") as trans_fh:
        bleu_output = subprocess.check_output(
            [multi_bleu_script, tgt_test], stdin=trans_fh).decode("utf-8")

    # extract BLEU score
    m = re.search("BLEU = (.+?),", bleu_output)
    bleu_score = float(m.group(1))
    return bleu_score
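

# ---------------------------------------------------------------------------
# Hedged illustration (not part of the original module): multi-bleu.perl
# prints a summary line such as "BLEU = 33.20, 64.5/40.2/28.7/21.4 (...)",
# and _moses_bleu() pulls the leading score out with a regular expression.
# The sample line below is invented for demonstration purposes only.
# ---------------------------------------------------------------------------
def _example_parse_moses_output():
    """Illustrative only: extract the BLEU score from a sample Moses line."""
    sample = "BLEU = 33.20, 64.5/40.2/28.7/21.4 (BP=0.998, ratio=0.998)"
    match = re.search("BLEU = (.+?),", sample)
    return float(match.group(1))  # -> 33.2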